import json
import logging
import os
import shutil
from collections import OrderedDict
from typing import List, Dict, Tuple, Iterable, Type
from zipfile import ZipFile
from inspect import getmembers, isfunction

import numpy as np
import pytorch_transformers
import torch
from numpy import ndarray
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from . import __DOWNLOAD_SERVER__
from .evaluation import SentenceEvaluator
from .util import import_from_string, batch_to_device, http_get
from . import __version__

class SentenceTransformer(nn.Sequential):
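    """
    Loads or assembles a SentenceTransformer model that maps sentences to embeddings.

    The model is a sequential stack of modules (e.g. a transformer model followed by a pooling
    layer). It can be built from an explicit list of modules, loaded from a local folder, or
    downloaded from the model server when only a plain model name is given.
    """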
    def __init__(self, model_name_or_path: str = None, modules: Iterable[nn.Module] = None, device: str = None):
        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])

        if model_name_or_path is not None and model_name_or_path != "":
            logging.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path))

            if '/' not in model_name_or_path and '\\' not in model_name_or_path and not os.path.isdir(model_name_or_path):
                logging.info("Did not find a / or \\ in the name. Assume to download model from server")
                model_name_or_path = __DOWNLOAD_SERVER__ + model_name_or_path + '.zip'

            if model_name_or_path.startswith('http://') or model_name_or_path.startswith('https://'):
                model_url = model_name_or_path
                folder_name = model_url.replace("https://", "").replace("http://", "").replace("/", "_")[:250]

                try:
                    from torch.hub import _get_torch_home
                    torch_cache_home = _get_torch_home()
                except ImportError:
                    torch_cache_home = os.path.expanduser(
                        os.getenv('TORCH_HOME', os.path.join(
                            os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
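                # Downloaded models are cached under e.g. ~/.cache/torch/sentence_transformers/<model folder>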
                default_cache_path = os.path.join(torch_cache_home, 'sentence_transformers')
                model_path = os.path.join(default_cache_path, folder_name)
                os.makedirs(model_path, exist_ok=True)


                if not os.listdir(model_path):
                    if model_url[-1] == "/":
                        model_url = model_url[:-1]
                    logging.info("Downloading sentence transformer model from {} and saving it at {}".format(model_url, model_path))
                    try:
                        zip_save_path = os.path.join(model_path, 'model.zip')
                        http_get(model_url, zip_save_path)
                        with ZipFile(zip_save_path, 'r') as zip_file:
                            zip_file.extractall(model_path)
                    except Exception as e:
                        shutil.rmtree(model_path)
                        raise e
            else:
                model_path = model_name_or_path

            #### Load from disk
            if model_path is not None:
                logging.info("Load SentenceTransformer from folder: {}".format(model_path))
                with open(os.path.join(model_path, 'modules.json')) as fIn:
                    contained_modules = json.load(fIn)
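                # Each entry in modules.json is expected to look like, for example (module type is illustrative):
                # {"idx": 0, "name": "0", "path": "0_BERT", "type": "sentence_transformers.models.BERT"}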

                modules = OrderedDict()
                for module_config in contained_modules:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(os.path.join(model_path, module_config['path']))
                    modules[module_config['name']] = module


        super().__init__(modules)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info("Use pytorch device: {}".format(device))
        self.device = torch.device(device)
        self.to(device)

    def encode(self, sentences: List[str], batch_size: int = 8, show_progress_bar: bool = None, token_vecs: bool=False) -> List[ndarray]:
        """
        Computes sentence embeddings
        :param sentences:
           the sentences to embed
        :param batch_size:
           the batch size used for the computation
        :param show_progress_bar:
            output a progress bar when encoding sentences
        :param token_vecs:
            if True, return the per-token embeddings and the corresponding tokens instead of sentence embeddings
        :return:
           a list with ndarrays of the embeddings for each sentence
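
        Example (illustrative; the model name is an assumption):
            model = SentenceTransformer('bert-base-nli-mean-tokens')
            embeddings = model.encode(['This is a sentence.', 'This is another sentence.'])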
        """
        if show_progress_bar is None:
            show_progress_bar = (logging.getLogger().getEffectiveLevel()==logging.INFO or logging.getLogger().getEffectiveLevel()==logging.DEBUG)

        all_embeddings = []
        all_token_embeddings = []
        all_tokens = []
        length_sorted_idx = np.argsort([len(sen) for sen in sentences])  # sort sentence indices from shortest to longest
        # length_sorted_idx holds the sentence indices in length-sorted order; the loop below iterates in this order
        # length_sorted_idx[0] is the index of the shortest sentence, i.e. sentences[length_sorted_idx[0]] is the shortest sentence
        iterator = range(0, len(sentences), batch_size)
        if show_progress_bar:
            iterator = tqdm(iterator, desc="Batches")
        # process the sentences batch by batch; padding/alignment is done within each batch
        for batch_idx in iterator:
            batch_tokens = []

            batch_start = batch_idx
            batch_end = min(batch_start + batch_size, len(sentences))  # do not run past the end of the list
            # for padding: tokenize the sentences first and track the longest token sequence in the batch
            longest_seq = 0
            for idx in length_sorted_idx[batch_start: batch_end]:
                sentence = sentences[idx]
                tokens = self.tokenize(sentence)  # tokenize the sentence
                longest_seq = max(longest_seq, len(tokens))
                batch_tokens.append(tokens)
            # next, extract features from the tokenized sentences:
            # a start token (id 101) is prepended and sequences are zero-padded to the longest length in the batch
            features = {}  # maps each feature name to the values collected for all sentences in the batch
            for text in batch_tokens:
                sentence_features = self.get_sentence_features(text, longest_seq)
                # sentence_features maps feature names to the corresponding values for this sentence
                for feature_name in sentence_features:  # feature_name is the key
                    if feature_name not in features:
                        features[feature_name] = []
                    features[feature_name].append(sentence_features[feature_name])
            # convert the collected values to tensors and move them to the model device
            for feature_name in features:
                features[feature_name] = torch.tensor(np.asarray(features[feature_name])).to(self.device)
            # store the embedding results
            with torch.no_grad():  # no gradients are computed here; no backpropagation
                embeddings = self.forward(features)  # compute the embedding representation of each sentence
                # compared to features, embeddings additionally contains sentence, token and cls_token embeddings
                sent_embeddings = embeddings['sentence_embedding'].to('cpu').numpy()
                # sent_embeddings holds the encoded result for each sentence, shape (batch_size, embedding_dim)
                all_embeddings.extend(sent_embeddings)
                raw_token_vecs = embeddings['token_embeddings'].to('cpu').numpy()
                token_ids = embeddings['input_ids'].to('cpu').numpy()
                for i,l in enumerate(embeddings['sentence_lengths'].to('cpu').numpy()):
                    all_token_embeddings.append(raw_token_vecs[i][:l])
                    all_tokens.append(self._first_module().ids_to_tokens(token_ids[i][:l]))
        # because padding required sorting by length, the embeddings were stored in length-sorted order;
        # to return them in the original input order, argsort the sorted indices again to invert the permutation
        reverting_order = np.argsort(length_sorted_idx)
        # reverting_order[i] is the position, within the length-sorted results, of the i-th input sentence
        all_embeddings = [all_embeddings[idx] for idx in reverting_order]
        all_token_embeddings = [all_token_embeddings[idx] for idx in reverting_order]
        all_tokens = [all_tokens[idx] for idx in reverting_order]

        if token_vecs:
            return all_token_embeddings, all_tokens
        else:
            return all_embeddings

    def tokenize(self, text):
        return self._first_module().tokenize(text)

    def list_functions(self):
        functions_list = [o for o in getmembers(self._first_module()) if isfunction(o[1])]
        print(functions_list)

    def get_sentence_features(self, *features):
        return self._first_module().get_sentence_features(*features)

    def get_sentence_embedding_dimension(self):
        return self._last_module().get_sentence_embedding_dimension()

    def _first_module(self):
        """Returns the first module of this sequential embedder"""
        return self._modules[next(iter(self._modules))]

    def _last_module(self):
        """Returns the last module of this sequential embedder"""
        return self._modules[next(reversed(self._modules))]

    def save(self, path):
        """
        Saves all elements for this seq. sentence embedder into different sub-folders
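
        Example of the resulting layout (illustrative; module folder names depend on the module classes):
            path/
                modules.json
                config.json
                0_BERT/
                1_Pooling/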
        """
        logging.info("Save model to {}".format(path))
        contained_modules = []

        for idx, name in enumerate(self._modules):
            module = self._modules[name]
            model_path = os.path.join(path, str(idx)+"_"+type(module).__name__)
            os.makedirs(model_path, exist_ok=True)
            module.save(model_path)
            contained_modules.append({'idx': idx, 'name': name, 'path': os.path.basename(model_path), 'type': type(module).__module__})

        with open(os.path.join(path, 'modules.json'), 'w') as fOut:
            json.dump(contained_modules, fOut, indent=2)

        with open(os.path.join(path, 'config.json'), 'w') as fOut:
            json.dump({'__version__': __version__}, fOut, indent=2)

    def smart_batching_collate(self, batch):
        """
        Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model

        :param batch:
            a batch from a SmartBatchingDataset
        :return:
            a batch of tensors for the model
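
        Example of the expected input (illustrative): each element is a (texts, label) pair,
        where texts holds one tokenized text per input of the training example:
            batch = [([tokens_text_a, tokens_text_b], label_tensor), ...]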
        """
        num_texts = len(batch[0][0])

        labels = []
        paired_texts = [[] for _ in range(num_texts)]
        max_seq_len = [0] * num_texts
        for tokens, label in batch:
            labels.append(label)
            for i in range(num_texts):
                paired_texts[i].append(tokens[i])
                max_seq_len[i] = max(max_seq_len[i], len(tokens[i]))

        features = []
        for idx in range(num_texts):
            max_len = max_seq_len[idx]
            feature_lists = {}
            for text in paired_texts[idx]:
                sentence_features = self.get_sentence_features(text, max_len)

                for feature_name in sentence_features:
                    if feature_name not in feature_lists:
                        feature_lists[feature_name] = []
                    feature_lists[feature_name].append(sentence_features[feature_name])

            for feature_name in feature_lists:
                feature_lists[feature_name] = torch.tensor(np.asarray(feature_lists[feature_name]))

            features.append(feature_lists)

        return {'features': features, 'labels': torch.stack(labels)}



    def fit(self,
            train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
            evaluator: SentenceEvaluator,
            epochs: int = 1,
            scheduler: str = 'WarmupLinear',
            warmup_steps: int = 10000,
            optimizer_class: Type[Optimizer] = pytorch_transformers.AdamW,
            optimizer_params : Dict[str, object ]= {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False},
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            output_path: str = None,
            save_best_model: bool = True,
            max_grad_norm: float = 1,
            fp16: bool = False,
            fp16_opt_level: str = 'O1',
            local_rank: int = -1
            ):
        """
        Train the model with the given training objective

        Each training objective is sampled in turn for one batch.
        We sample only as many batches from each objective as there are in the smallest one
        to ensure equal training with each dataset.

        :param weight_decay:
        :param scheduler:
        :param warmup_steps:
        :param optimizer_class:
        :param optimizer_params:
        :param evaluation_steps:
        :param output_path:
        :param save_best_model:
        :param max_grad_norm:
        :param fp16:
        :param fp16_opt_level:
        :param local_rank:
        :param train_objectives:
            Tuples of DataLoader and LossConfig
        :param evaluator:
        :param epochs:
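
        Example usage (illustrative; the dataloader, loss and evaluator names are assumptions):
            model.fit(train_objectives=[(train_dataloader, train_loss)],
                      evaluator=dev_evaluator,
                      epochs=1,
                      warmup_steps=100,
                      output_path='output/my_model')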
        """
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            if os.listdir(output_path):
                raise ValueError("Output directory ({}) already exists and is not empty.".format(
                    output_path))

        dataloaders = [dataloader for dataloader, _ in train_objectives]

        # Use smart batching
        for dataloader in dataloaders:
            dataloader.collate_fn = self.smart_batching_collate

        loss_models = [loss for _, loss in train_objectives]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for loss_model in loss_models:
            loss_model.to(device)

        self.best_score = -9999

        steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])  # number of batches in the smallest dataloader
        num_train_steps = int(steps_per_epoch * epochs)

        # Prepare optimizers
        optimizers = []
        schedulers = []
        for loss_model in loss_models:
            param_optimizer = list(loss_model.named_parameters())
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
            t_total = num_train_steps
            if local_rank != -1:
                t_total = t_total // torch.distributed.get_world_size()

            optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
            scheduler = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=t_total)

            optimizers.append(optimizer)
            schedulers.append(scheduler)

        if fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

            for idx in range(len(loss_models)):
                model, optimizer = amp.initialize(loss_models[idx], optimizers[idx], opt_level=fp16_opt_level)
                loss_models[idx] = model
                optimizers[idx] = optimizer

        global_step = 0
        data_iterators = [iter(dataloader) for dataloader in dataloaders]

        num_train_objectives = len(train_objectives)
        for epoch in trange(epochs, desc="Epoch"):
            training_steps = 0

            for loss_model in loss_models:
                loss_model.zero_grad()
                loss_model.train()

            for step in trange(num_train_objectives * steps_per_epoch, desc="Iteration"):
                idx = step % num_train_objectives

                loss_model = loss_models[idx]
                optimizer = optimizers[idx]
                scheduler = schedulers[idx]
                data_iterator = data_iterators[idx]

                try:
                    data = next(data_iterator)
                except StopIteration:
                    logging.info("Restart data_iterator")
                    data_iterator = iter(dataloaders[idx])
                    data_iterators[idx] = data_iterator
                    data = next(data_iterator)

                features, labels = batch_to_device(data, self.device)
                loss_value = loss_model(features, labels)

                if fp16:
                    with amp.scale_loss(loss_value, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                else:
                    loss_value.backward()
                    torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)

                training_steps += 1

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                if evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(evaluator, output_path, save_best_model, epoch, training_steps)
                    for loss_model in loss_models:
                        loss_model.zero_grad()
                        loss_model.train()

            self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1)

    def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None):
        """
        Evaluate the model

        :param evaluator:
            the evaluator
        :param output_path:
            the evaluator can write the results to this path
        """
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
        return evaluator(self, output_path)

    def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps):
        """Runs evaluation during the training"""
        if evaluator is not None:
            score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps)
            if score > self.best_score and save_best_model:
                self.save(output_path)
                self.best_score = score


    def _get_scheduler(self, optimizer, scheduler: str, warmup_steps: int, t_total: int):
        """
        Returns the correct learning rate scheduler
        """
        scheduler = scheduler.lower()
        if scheduler == 'constantlr':
            return pytorch_transformers.ConstantLRSchedule(optimizer)
        elif scheduler == 'warmupconstant':
            return pytorch_transformers.WarmupConstantSchedule(optimizer, warmup_steps=warmup_steps)
        elif scheduler == 'warmuplinear':
            return pytorch_transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)
        elif scheduler == 'warmupcosine':
            return pytorch_transformers.WarmupCosineSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)
        elif scheduler == 'warmupcosinewithhardrestarts':
            return pytorch_transformers.WarmupCosineWithHardRestartsSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)
        else:
            raise ValueError("Unknown scheduler {}".format(scheduler))
